from __future__ import print_function

import argparse
import glob
import os
import sys

import pandas as pd

# Command-line settings. --path is the directory whose entries are individual
# run directories, each expected to contain a log.csv.
parser = argparse.ArgumentParser(
    description='Summarize per-epoch sort-of-CLEVR accuracies from a directory of run logs')
parser.add_argument('--path', type=str, required=True,
                    help="directory whose sub-directories each hold one run's log.csv")
args = parser.parse_args()

# Columns of the summary table; each row is one (run, epoch) checkpoint.
COLUMNS = ["Model", "Epoch", "Iterations", "Heads", "Rules",
           "Dimension", "Attention Dimension", "Question Injection", "Recurrence",
           "Ternary Test", "Binary Test", "Unary Test",
           "Ternary Train", "Binary Train", "Unary Train"]

# Epochs sampled from every run's log: 10, 20, ..., 100.
EPOCHS = range(10, 110, 10)

# Collect plain lists and build the DataFrame once at the end; growing a
# DataFrame one row at a time re-allocates it on every insert.
rows = []

# Sorted so the table order is reproducible; glob's order depends on the
# filesystem.
for file in sorted(glob.glob(os.path.join(args.path, '*'))):
    # Only run directories carry a log.csv; skip stray files next to them.
    if not os.path.isdir(file):
        continue

    # basename() handles the platform's path separator (not just '/').
    name = os.path.basename(file)

    # Boolean flags are encoded as substrings of the run name; kept as the
    # strings "True"/"False" in the table.
    recurrent = "True" if "recurrent" in name else "False"
    inj = "True" if "inject-ques" in name else "False"

    # Hyper-parameters are the first six '_'-separated fields of the run name:
    # <model>_<iterations>_<dim>_<att_dim>_<heads>_<rules>[_...]
    name_splits = name.split('_')
    if len(name_splits) < 6:
        print("Cannot parse run name (expected at least 6 '_'-separated fields):",
              file=sys.stderr)
        print(name, file=sys.stderr)
        sys.exit(1)
    model, iterations, dim, att_dim, heads, rules = name_splits[:6]

    # splitlines() drops the line terminators and keeps the last record even
    # when the file has no final newline.
    with open(os.path.join(file, 'log.csv'), 'r', encoding='utf-8') as f:
        data = f.read().splitlines()

    for epoch in EPOCHS:
        # data[0] is presumably a header row, so row `epoch` must log that same
        # epoch in its first field; otherwise the log is truncated or has a
        # different layout.
        line = data[epoch].split(',') if epoch < len(data) else None
        if line is None or epoch != int(line[0]):
            print("Not compatible ... Error", file=sys.stderr)
            print(name, file=sys.stderr)
            # Non-zero status so shells and calling scripts see the failure.
            sys.exit(1)

        # Field order in log.csv: epoch, ternary/binary/unary train accuracy,
        # then ternary/binary/unary test accuracy.
        ternary_train, binary_train, unary_train = (float(v) for v in line[1:4])
        ternary_test, binary_test, unary_test = (float(v) for v in line[4:7])

        rows.append([model, epoch, iterations, heads, rules, dim, att_dim, inj, recurrent,
                     ternary_test, binary_test, unary_test,
                     ternary_train, binary_train, unary_train])

df = pd.DataFrame(rows, columns=COLUMNS)

# Epoch whose results are reported.
REPORT_EPOCH = 20

# mergesort is stable: rows with equal ternary test accuracy keep their table
# order, so repeated runs print identically.
ternary_sort = df.sort_values(by=['Ternary Test'], kind='mergesort')

print(ternary_sort.loc[ternary_sort['Epoch'] == REPORT_EPOCH])
# exit()
#
# df = pd.read_csv(args.path)
# if "train_acc_ternary.1" in df.columns:
#     df = df.rename(columns={"train_acc_ternary.1": "test_acc_ternary"})
#
# print(df)
# print()
# print("Final Accuracies: ")
# print(df.iloc[-1].to_frame().T)
# print()
# print("Peak Accuracies: ")
# print(df.max().to_frame().T)
